package oauth

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"log/slog"
	"net/http"
	"net/url"
	"strings"
	"sync"
	"time"

	"github.com/bluesky-social/indigo/atproto/atclient"
	"github.com/bluesky-social/indigo/atproto/atcrypto"
	"github.com/bluesky-social/indigo/atproto/syntax"

	"github.com/golang-jwt/jwt/v5"
	"github.com/google/go-querystring/query"
)

// PersistSessionCallback is an optional hook which [ClientSession] calls with the updated session data after it changes (after a successful token refresh, and after a host DPoP nonce update), so the application can save it.
//
// It is invoked while the session's lock is held for writing; implementations must not call back into locking methods of the same ClientSession.
type PersistSessionCallback = func(ctx context.Context, data *ClientSessionData)

// ClientSessionData is the persisted information about an OAuth session, used to resume an active session.
//
// It is serialized as JSON (see the field tags). It includes secrets (access and refresh tokens, and the DPoP private key), so it should be stored securely. [ClientSession] updates the token and nonce fields in place as the session is used.
type ClientSessionData struct {
	// Account DID for this session. Assuming only one active session per account, this can be used as "primary key" for storing and retrieving this information.
	AccountDID syntax.DID `json:"account_did"`

	// Identifier to distinguish this particular session for the account. Server backends generally support multiple sessions for the same account. This package will re-use the random 'state' token from the auth flow as the session ID.
	SessionID string `json:"session_id"`

	// Base URL of the "resource server" (eg, PDS). Should include scheme, hostname, port; no path or auth info.
	HostURL string `json:"host_url"`

	// Base URL of the "auth server" (eg, PDS or entryway). Should include scheme, hostname, port; no path or auth info.
	AuthServerURL string `json:"authserver_url"`

	// Full URL of the auth server token endpoint (used for token refresh requests)
	AuthServerTokenEndpoint string `json:"authserver_token_endpoint"`

	// Full URL of the auth server revocation endpoint, if it exists. Empty if the auth server does not support revocation.
	AuthServerRevocationEndpoint string `json:"authserver_revocation_endpoint,omitempty"`

	// The set of scopes approved for this session (returned in the initial token request)
	Scopes []string `json:"scopes"`

	// Token which can be used directly against host ("resource server", eg PDS)
	AccessToken string `json:"access_token"`

	// Token which can be sent to auth server (eg, PDS or entryway) to get a new access token
	RefreshToken string `json:"refresh_token"`

	// Current auth server DPoP nonce; updated whenever the auth server returns a "DPoP-Nonce" response header
	DPoPAuthServerNonce string `json:"dpop_authserver_nonce"`

	// Current host ("resource server", eg PDS) DPoP nonce
	DPoPHostNonce string `json:"dpop_host_nonce"`

	// The secret cryptographic key generated by the client for this specific OAuth session, in multibase encoding.
	// NOTE(review): presumably the serialized form of ClientSession.DPoPPrivateKey — confirm against the code that loads it.
	DPoPPrivateKeyMultibase string `json:"dpop_privatekey_multibase"`

	// TODO: also persist access token creation time / expiration time? In context that token might not be an easily parsed JWT
}

// ClientSession implements [atclient.AuthMethod] for an OAuth session. It handles DPoP request token signing and nonce rotation, and token refresh requests. Optionally uses a callback to persist updated session data.
//
// A single ClientSession instance can be called concurrently: updates to session data (the 'Data' field) are protected with a RW mutex lock. Note that concurrent calls to distinct ClientSession instances for the same session could result in clobbered session data.
//
// ClientSession contains a mutex, so it must not be copied after first use; always use it by pointer.
type ClientSession struct {
	// HTTP client used for requests to the auth server (token refresh and revocation). Also used as the HTTP client of clients returned by APIClient.
	Client *http.Client

	// Client configuration: supplies the client ID, optional User-Agent, and client assertions for confidential clients.
	Config         *ClientConfig
	// Session data (endpoints, tokens, nonces). Methods which read or update tokens and nonces take lk.
	Data           *ClientSessionData
	// Key used to sign DPoP proofs for both the auth server and the host.
	DPoPPrivateKey atcrypto.PrivateKey

	// Optional; called with the updated Data after a token refresh or host DPoP nonce update. Called while lk is held for writing, so it must not call back into locking methods of this ClientSession.
	PersistSessionCallback PersistSessionCallback

	// Lock which protects concurrent access to session data (eg, access and refresh tokens)
	lk sync.RWMutex
}

// postToAuthServer sends a form-encoded POST request to an auth server endpoint (eg, the token or revocation endpoint). It handles DPoP proofs and nonce retries, and adds client assertions if the client is confidential.
//
// The body object is url-encoded with query.Values (expected to be either RefreshTokenRequest or RevocationRequest).
//
// The caller must hold sess.lk for writing: whenever the auth server supplies a "DPoP-Nonce" header, this method stores it in sess.Data.DPoPAuthServerNonce.
//
// At most two attempts are made. If the first fails with a "use_dpop_nonce" error, the request is retried once with the updated nonce. If a non-nil *http.Response is returned, the caller is responsible for closing the response body.
func (sess *ClientSession) postToAuthServer(ctx context.Context, reqURL string, body any) (*http.Response, error) {
	vals, err := query.Values(body)
	if err != nil {
		return nil, err
	}
	if sess.Config.IsConfidential() {
		clientAssertion, err := sess.Config.NewClientAssertion(sess.Data.AuthServerURL)
		if err != nil {
			return nil, err
		}
		vals.Set("client_assertion_type", ClientAssertionJWTBearer)
		vals.Set("client_assertion", clientAssertion)
	}
	bodyBytes := []byte(vals.Encode())

	// the initial attempt, plus a single retry if the server demands a fresh DPoP nonce
	const maxAttempts = 2
	for attempt := 1; ; attempt++ {
		dpopJWT, err := NewAuthDPoP("POST", reqURL, sess.Data.DPoPAuthServerNonce, sess.DPoPPrivateKey)
		if err != nil {
			return nil, err
		}

		req, err := http.NewRequestWithContext(ctx, "POST", reqURL, bytes.NewBuffer(bodyBytes))
		if err != nil {
			return nil, err
		}
		req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
		req.Header.Set("DPoP", dpopJWT)

		resp, err := sess.Client.Do(req)
		if err != nil {
			return nil, err
		}

		// always check if a new DPoP nonce was provided, and proactively update session data (even if there was not an explicit error)
		dpopNonceHdr := resp.Header.Get("DPoP-Nonce")
		if dpopNonceHdr != "" {
			sess.Data.DPoPAuthServerNonce = dpopNonceHdr
		}

		// check for an error condition caused by an out of date DPoP nonce
		// note that the HTTP status code is 400 Bad Request on the Auth Server token endpoint, not 401 Unauthorized like it would be on Resource Server requests
		if resp.StatusCode == http.StatusBadRequest && dpopNonceHdr != "" {
			// parseAuthErrorReason() always closes resp.Body, so this response must not be returned to the caller
			reason := parseAuthErrorReason(resp, "token-refresh")
			if reason == "use_dpop_nonce" && attempt < maxAttempts {
				// already updated nonce value above; loop around and try again
				continue
			}
			return nil, fmt.Errorf("auth server request failed (HTTP %d): %s", resp.StatusCode, reason)
		}

		// otherwise hand back the response (success or other error type) with its body still open
		return resp, nil
	}
}

// RefreshTokens requests new tokens from the auth server, and returns the new access token on success.
//
// Internally takes a write lock on session data around the entire refresh process, including DPoP nonce retries. On success, persists the updated session data (tokens and possibly the auth server nonce) using [PersistSessionCallback] if configured; otherwise logs a warning.
//
// If the token response does not include a new refresh token, the existing one is kept. RFC 6749 §6 lets the server choose whether to rotate it. A response without an access token is treated as an error and leaves session data unchanged.
func (sess *ClientSession) RefreshTokens(ctx context.Context) (string, error) {
	sess.lk.Lock()
	defer sess.lk.Unlock()

	body := RefreshTokenRequest{
		ClientID:     sess.Config.ClientID,
		GrantType:    "refresh_token",
		RefreshToken: sess.Data.RefreshToken,
	}

	resp, err := sess.postToAuthServer(ctx, sess.Data.AuthServerTokenEndpoint, body)
	if err != nil {
		return "", fmt.Errorf("token refresh failed: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		reason := parseAuthErrorReason(resp, "token-refresh")
		return "", fmt.Errorf("token refresh failed (HTTP %d): %s", resp.StatusCode, reason)
	}

	var tokenResp TokenResponse
	if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
		return "", fmt.Errorf("token response failed to decode: %w", err)
	}
	// an empty access token would silently break every subsequent host request
	if tokenResp.AccessToken == "" {
		return "", fmt.Errorf("token refresh failed: response did not include an access token")
	}
	// TODO: more validation of token refresh response?

	sess.Data.AccessToken = tokenResp.AccessToken
	// only replace the refresh token if the server issued a new one; never clobber it with an empty value
	if tokenResp.RefreshToken != "" {
		sess.Data.RefreshToken = tokenResp.RefreshToken
	}

	// persist updated data (tokens and possibly nonce)
	if sess.PersistSessionCallback != nil {
		sess.PersistSessionCallback(ctx, sess.Data)
	} else {
		slog.Warn("not saving updated session data", "did", sess.Data.AccountDID, "session_id", sess.Data.SessionID)
	}

	return sess.Data.AccessToken, nil
}

// RevokeSession uses the auth server's revocation endpoint (if it has one) to revoke both the access token and the refresh token.
//
// Returns an error immediately if the auth server has no revocation endpoint. Otherwise both revocations are attempted, even if the first fails. The returned error joins any per-token failures, and is nil only if both requests succeeded with HTTP 200.
//
// Takes a write lock on session data for the duration of both requests.
func (sess *ClientSession) RevokeSession(ctx context.Context) error {
	sess.lk.Lock()
	defer sess.lk.Unlock()

	if sess.Data.AuthServerRevocationEndpoint == "" {
		return fmt.Errorf("AS does not support token revocation")
	}

	// revoke sends a single revocation request; label is the human-readable token kind used in error messages
	revoke := func(token, hint, label string) error {
		resp, err := sess.postToAuthServer(ctx, sess.Data.AuthServerRevocationEndpoint, RevocationRequest{
			ClientID:      sess.Config.ClientID,
			Token:         token,
			TokenTypeHint: hint,
		})
		if err != nil {
			return fmt.Errorf("failed revoking %s: %w", label, err)
		}
		defer resp.Body.Close()
		if resp.StatusCode != http.StatusOK {
			return fmt.Errorf("bad HTTP status while revoking %s (%d)", label, resp.StatusCode)
		}
		return nil
	}

	errAccess := revoke(sess.Data.AccessToken, "access_token", "access token")
	errRefresh := revoke(sess.Data.RefreshToken, "refresh_token", "refresh token")

	return errors.Join(errAccess, errRefresh) // returns nil if both errors are nil
}

// NewHostDPoP builds and signs a DPoP proof JWT for the "DPoP" header of a request to the Host (aka Resource Server, aka PDS). These proofs differ from those used with Auth Server token endpoints (even when the PDS fills both roles).
//
// The proof is bound to the current access token through the "ath" claim. It carries the current host DPoP nonce when one is known. The public key is embedded in the "jwk" header. Session data is read under the read lock.
func (sess *ClientSession) NewHostDPoP(method, reqURL string) (string, error) {
	sess.lk.RLock()
	defer sess.lk.RUnlock()

	data := sess.Data
	accessTokenHash := S256CodeChallenge(data.AccessToken)

	registered := jwt.RegisteredClaims{
		Issuer:    data.AuthServerURL,
		ID:        secureRandomBase64(16),
		IssuedAt:  jwt.NewNumericDate(time.Now()),
		ExpiresAt: jwt.NewNumericDate(time.Now().Add(jwtExpirationDuration)),
	}
	claims := dpopClaims{
		HTTPMethod:       method,
		TargetURI:        reqURL,
		AccessTokenHash:  &accessTokenHash,
		RegisteredClaims: registered,
	}
	// the nonce claim is omitted entirely until the host has supplied one
	if nonce := data.DPoPHostNonce; nonce != "" {
		claims.Nonce = &nonce
	}

	alg, err := keySigningMethod(sess.DPoPPrivateKey)
	if err != nil {
		return "", err
	}

	// TODO: store a copy of this JWK on the ClientSession as a private field, for efficiency
	publicKey, err := sess.DPoPPrivateKey.PublicKey()
	if err != nil {
		return "", err
	}
	jwk, err := publicKey.JWK()
	if err != nil {
		return "", err
	}

	proof := jwt.NewWithClaims(alg, claims)
	proof.Header["typ"] = "dpop+jwt"
	proof.Header["jwk"] = jwk
	return proof.SignedString(sess.DPoPPrivateKey)
}

// dpopURL returns u rendered as a string with its query string and fragment removed, for use as the DPoP target URI. u itself is left unmodified; a shallow copy is edited instead.
func dpopURL(u *url.URL) string {
	trimmed := *u
	trimmed.RawQuery, trimmed.ForceQuery = "", false
	trimmed.Fragment, trimmed.RawFragment = "", ""
	return trimmed.String()
}

// isNonceUpdateHeader reports whether a WWW-Authenticate response header signals that a fresh DPoP nonce is required. It looks for the quoted `error="use_dpop_nonce"` parameter.
//
// Example from RFC 9449:
//
//	WWW-Authenticate: DPoP error="use_dpop_nonce", error_description="Resource server requires nonce in DPoP proof"
func isNonceUpdateHeader(hdr string) bool {
	const marker = `error="use_dpop_nonce"`
	return strings.Contains(hdr, marker)
}

// isExpiredAccessTokenHeader reports whether a WWW-Authenticate response header indicates the access token is no longer valid (and so needs refreshing). It looks for the quoted `error="invalid_token"` parameter.
//
// Example from the OAuth 2.1 draft:
//
//	WWW-Authenticate: Bearer error="invalid_token" error_description="The access token expired"
func isExpiredAccessTokenHeader(hdr string) bool {
	// TODO: should this also look for "expired"?
	const marker = `error="invalid_token"`
	return strings.Contains(hdr, marker)
}

// GetHostAccessData returns the current access token and host DPoP nonce, read together under the session read lock so the pair is consistent.
func (sess *ClientSession) GetHostAccessData() (accessToken string, dpopHostNonce string) {
	sess.lk.RLock()
	defer sess.lk.RUnlock()

	data := sess.Data
	accessToken = data.AccessToken
	dpopHostNonce = data.DPoPHostNonce
	return accessToken, dpopHostNonce
}

// UpdateHostDPoPNonce stores a new host ("resource server") DPoP nonce under the session write lock. It then persists session data via PersistSessionCallback, or logs a warning if no callback is configured. The callback runs while the lock is still held.
func (sess *ClientSession) UpdateHostDPoPNonce(ctx context.Context, nonce string) {
	sess.lk.Lock()
	defer sess.lk.Unlock()

	sess.Data.DPoPHostNonce = nonce

	persist := sess.PersistSessionCallback
	if persist == nil {
		slog.Warn("not saving updated host DPoP nonce", "did", sess.Data.AccountDID, "session_id", sess.Data.SessionID)
		return
	}
	persist(ctx, sess.Data)
}

// DoWithAuth sends an API request to the OAuth Resource Server (PDS), authenticated with the session access token and a DPoP proof.
//
// DPoP nonce updates and token refresh are handled automatically, based on the response status code and `WWW-Authenticate` header. At most three attempts are made: the initial request, plus retries after a nonce update and/or a token refresh. Retried requests get a fresh body from req.GetBody when it is set.
//
// The endpoint argument is currently unused.
//
// When a response is returned, the caller is responsible for closing its body. Intermediate 401 responses which are discarded (retried, or turned into an error) are closed here.
func (sess *ClientSession) DoWithAuth(c *http.Client, req *http.Request, endpoint syntax.NSID) (*http.Response, error) {

	durl := dpopURL(req.URL)

	accessToken, dpopNonce := sess.GetHostAccessData()

	// rewind returns a copy of r suitable for re-sending, with a fresh body when r supports it
	rewind := func(r *http.Request) (*http.Request, error) {
		retry := r.Clone(r.Context())
		if r.GetBody != nil {
			body, err := r.GetBody()
			if err != nil {
				return nil, fmt.Errorf("GetBody failed when retrying API request: %w", err)
			}
			retry.Body = body
		}
		return retry, nil
	}

	// this method may need to retry twice, once for DPoP nonce update and once for token refresh
	for range 3 {
		dpopJWT, err := sess.NewHostDPoP(req.Method, durl)
		if err != nil {
			return nil, err
		}
		req.Header.Set("Authorization", "DPoP "+accessToken)
		req.Header.Set("DPoP", dpopJWT)

		resp, err := c.Do(req)
		if err != nil {
			return nil, err
		}

		// on Success, or many types of error, just return HTTP response
		// "Unauthorized" is HTTP status code 401
		authHdr := resp.Header.Get("WWW-Authenticate")
		if resp.StatusCode != http.StatusUnauthorized || authHdr == "" {
			return resp, nil
		}
		dpopNonceHdr := resp.Header.Get("DPoP-Nonce")

		switch {
		case isNonceUpdateHeader(authHdr) && dpopNonceHdr != "":
			// this response is discarded either way (retry or error), so release it now
			resp.Body.Close()

			// TODO: validate or normalize dpopNonceHdr in some way? eg minimum length
			if dpopNonceHdr == dpopNonce {
				return nil, fmt.Errorf("OAuth PDS DPoP nonce failure, but no new nonce supplied")
			}

			// persist new nonce value via callback
			sess.UpdateHostDPoPNonce(req.Context(), dpopNonceHdr)
			dpopNonce = dpopNonceHdr

		case isExpiredAccessTokenHeader(authHdr):
			// release the 401 before the refresh makes its own network requests
			resp.Body.Close()

			accessToken, err = sess.RefreshTokens(req.Context())
			if err != nil {
				return nil, fmt.Errorf("failed to refresh OAuth tokens: %w", err)
			}

		default:
			// otherwise, this was some other type of auth failure; just return the full response
			// NOTE: in theory we could return an APIError here instead
			return resp, nil
		}

		// retry request
		req, err = rewind(req)
		if err != nil {
			return nil, err
		}
	}

	return nil, fmt.Errorf("OAuth client ran out of request retries")
}

// APIClient returns a new [atclient.APIClient] which uses this session for auth. The client targets the session's host (PDS) using the session HTTP client. If the client config specifies a UserAgent, it is sent as a default "User-Agent" header.
func (sess *ClientSession) APIClient() *atclient.APIClient {
	client := &atclient.APIClient{
		Client:     sess.Client,
		Host:       sess.Data.HostURL,
		Auth:       sess,
		AccountDID: &sess.Data.AccountDID,
	}
	if ua := sess.Config.UserAgent; ua != "" {
		client.Headers = make(map[string][]string)
		client.Headers.Set("User-Agent", ua)
	}
	return client
}
